import os
import json
import base64
import cv2
import traceback
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import openai
from threading import Lock

# Initialize OpenAI Client
# Credentials are read from the environment so that secrets never have to be
# written into this source file:
#   OPENAI_API_KEY  - API key; falls back to the empty placeholder when unset,
#                     so importing this module still succeeds.
#   OPENAI_BASE_URL - optional endpoint override; when unset (or empty) None is
#                     passed so the SDK picks its default endpoint instead of
#                     being handed an invalid empty URL.
client = openai.OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY", ""),
    base_url=os.environ.get("OPENAI_BASE_URL") or None,
    timeout=2000,  # per-request timeout, in seconds
    max_retries=3,
)

# Call GPT-4 API to generate scene expression
def gpt4(frames):
    """Send one user message to the chat model and return its text reply.

    Args:
        frames: The ``content`` of the user message; passed to the API
            unchanged (``process_video`` builds it as a list of text and
            ``image_url`` parts).

    Returns:
        The text content of the first choice in the response.

    Raises:
        ValueError: If a ``json.JSONDecodeError`` escapes from the API call.
        RuntimeError: For any other failure during the call, or when the
            reply carries no text content.
    """
    try:
        result = client.chat.completions.create(
            messages=[{"role": "user", "content": frames}],
            model="gpt-4o",  # Correct model name
        )
        response = result.choices[0].message.content
    except json.JSONDecodeError as e:
        # Chain the cause so the original traceback is not lost.
        raise ValueError("Failed to parse the GPT response as JSON.") from e
    except Exception as e:
        raise RuntimeError(f"Error during GPT call: {e}") from e

    # Fail here with a clear message rather than handing None to the caller,
    # which would only blow up later when the result is written to disk.
    if response is None:
        raise RuntimeError("Error during GPT call: response contained no text content")
    return response

# Prompt templates, one per requested output representation.
# Each template carries a single `{categories_of_interest}` placeholder.
#
# gen_description: asks for a natural-language description of every object's
# position relative to a global reference point, returned as a JSON dictionary.
# This is the template used by process_video().
#
# NOTE: the template also contains literal JSON braces (the output-format and
# example sections). str.format() would parse those as replacement fields, so
# the placeholder has to be filled in another way (e.g. str.replace) unless the
# literal braces are escaped.
gen_description = """
[Task]
You are given a video (multiple frames) capturing an indoor scene. Your goal is to recognize {categories_of_interest} objects, analyze the spatial layout of the scene, and describe the relative position of each object.

[Instructions]
1. Per-frame analysis:
- For each frame, choose one object as a **local reference point**.
- Predict the relative position of all other objects with respect to this local reference point.
- Express relative positions using simple terms like "left", "right", "front", "behind", "above", "below", and approximate distances (e.g., "2 meters to the right").

2. Global scene layout:
- Take the **local reference point from the first frame** as the **global reference point** for the whole video.
- Use overlapping objects between frames to align frames together.
- Gradually build the spatial descriptions for all objects relative to the global reference point.

[Rules]
- If a category has multiple instances (e.g., two chairs), describe each instance separately.
- Preserve the real-world spatial relationships and distances as accurately as possible.
- Use clear and consistent directional and distance terms.

[Output Format]
ONLY Return the result as a JSON dictionary following STRICTLY this format:
{
    "category name": "global reference point",
    "another category name": "Position description relative to the global reference point",
    ...
}

Example:
{
    "chair": "The chair is located 1.5 meters to the left and 0.5 meters behind the global reference point",
    "table": "The table is located 2 meters to the right of the global reference point",
    "lamp": "The lamp is located 1 meter above and 1 meter behind the global reference point"
}
"""

# gen_3D_map: asks for per-frame 3D object coordinates (relative to a per-frame
# local origin) plus a rotation matrix / translation vector linking each frame
# to the previous one, returned as a JSON dictionary keyed by "frame_N".
# Carries a `{categories_of_interest}` placeholder and literal JSON braces, so
# str.format() cannot be applied to it as-is.
# NOTE(review): not referenced anywhere in the visible part of this file.
gen_3D_map = """
[Task]
You are given a **sequence of image frames** from an indoor video. Each frame captures a partial view of a 3D scene. Your task is to:

1. Detect and label the visible {categories_of_interest} objects in each frame.
2. For every frame:
   - Randomly choose one object as the **local reference origin** for that frame.
   - Assign it the coordinate (0, 0, 0).
   - Estimate the 3D coordinates [x, y, z] of all other objects in that frame, **relative to the local origin**.
3. For each frame (starting from frame 2 onward):
   - Estimate the **relative transformation matrix** from the previous frame to the current frame.
   - Express it as:
     {
       "rotation_matrix": [[r11, r12, r13], [r21, r22, r23], [r31, r32, r33]],
       "translation_vector": [dx, dy, dz]
     }
   - The rotation matrix should represent 3D rotation (e.g., derived from overlapping object positions), and the translation vector represents displacement in meters.
4. Track and unify object identities across frames (e.g., same table across multiple frames should be named consistently).
5. Eliminate duplicates when an object appears in more than one frame.

[Output Format]
Return ONLY a JSON dictionary with the following format:
{
  "frame_1": {
    "objects": {
      "object_A": [x, y, z],
      "object_B": [x, y, z]
    },
    "offset_to_prev": null
  },
  "frame_2": {
    "objects": {
      ...
    },
    "offset_to_prev": {
      "rotation_matrix": [[r11, r12, r13], [r21, r22, r23], [r31, r32, r33]],
      "translation_vector": [dx, dy, dz]
    }
  },
  "frame_3": {
    ...
  }
}

[Rules]
- Object positions must be expressed in meters.
- Object names must be consistent across frames (e.g., "chair_1", "table_2").
- Use approximate values based on visual cues.
- Only return the final JSON result. Do not include any explanation or preamble.

[Input]
You will receive a list of image frames in order.
Start with frame 1 as reference and compute subsequent positions and relative offsets.
"""

# gen_2D_grid: asks for the center point of every object instance projected
# onto a 10x10 2D grid, returned as a dictionary mapping category name to a
# list of (x, y) points.
# Carries a `{categories_of_interest}` placeholder and literal braces, so
# str.format() cannot be applied to it as-is.
# NOTE(review): the prompt requests JSON but its examples use (x, y) tuple
# syntax, which is not valid JSON - confirm how replies are parsed.
# NOTE(review): not referenced anywhere in the visible part of this file.
gen_2D_grid = """
[Task]
You are given a video capturing an indoor scene. Your objective is to identify and localize specific objects within the scene. The entire scene is projected onto a **10x10 2D grid**, where (0,0) is the bottom-left and (9,9) is the top-right corner. You are to estimate the **center point** of each object of interest within this grid space.

[Rules]
1. You should consider ONLY the following object categories in this task: {categories_of_interest}.
2. For each detected instance of these categories, estimate its center location as an (x, y) coordinate on the 10x10 grid.
3. If multiple instances of the same category exist, include **all of them** in the output.
4. Your estimations should **preserve the relative spatial layout** of the scene as accurately as possible. That is, if one object is clearly to the left or behind another, this should be reflected in their grid coordinates.
5. All coordinates must be expressed as **floating point numbers** between 0 and 9, accurate to at least one decimal place.
6. Do not include any categories or objects not listed above, even if they appear in the scene.

[Output Format]
Return ONLY a JSON dictionary with the following structure:
{
  "category_name_1": [(x1, y1), (x2, y2), ...],
  "category_name_2": [(x1, y1)],
  ...
}

Each key must match one of the specified categories, and each value is a list of estimated (x, y) center points for all visible instances of that category.

[Example Output]
{
  "chair": [(2.5, 3.0), (6.8, 3.2)],
  "table": [(4.1, 5.7)]
}

[Note]
- Do not return any explanation or text before or after the JSON result.
- Begin your response with the JSON object only.
"""

# Categories of interest substituted into the prompt template; loaded once at
# import time. The context manager closes the file handle deterministically
# (the bare open() call previously left it to the garbage collector).
with open('object_list.json', encoding='utf-8') as _obj_list_file:
    obj_list = json.load(_obj_list_file)

# Extract base64-encoded frames from the video
def process_video(video_path, num_frames=8, resolution=(320, 240)):
    """Build the chat message content (task text + sampled frames) for a video.

    Up to ``num_frames`` frames are sampled at a fixed stride of
    ``total_frames // num_frames`` (at least 1), resized to ``resolution``,
    JPEG-encoded at quality 50 and embedded as base64 data URLs.

    Args:
        video_path: Path of the video file to read.
        num_frames: Maximum number of frames to sample.
        resolution: Target size passed to ``cv2.resize``.

    Returns:
        A list whose first element is the text part holding the task prompt,
        followed by one ``image_url`` part per successfully read frame.

    Raises:
        ValueError: If the video file cannot be opened.
    """
    # str.format() cannot be used on the template: it contains literal JSON
    # braces, which format() parses as replacement fields and fails with
    # KeyError. Substitute the single placeholder directly instead.
    task_text = gen_description.replace("{categories_of_interest}", str(obj_list))

    # Add task description as text
    base64Frames = [{"type": "text", "text": task_text}]

    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise ValueError(f"Failed to open video file: {video_path}")

        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_interval = max(1, total_frames // num_frames)

        for i in range(num_frames):
            video.set(cv2.CAP_PROP_POS_FRAMES, i * frame_interval)
            success, frame = video.read()
            if not success:
                # Ran past the last readable frame; keep what we have.
                break
            frame = cv2.resize(frame, resolution)
            _, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 50])
            image_a = base64.b64encode(buffer).decode("utf-8")

            base64Frames.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image_a}"},
            })
    finally:
        # Release the capture even if reading/encoding raised.
        video.release()

    print(f"Extracted {len(base64Frames)-1} frames from {os.path.basename(video_path)}")
    return base64Frames

# Process a single video file
def process_single_video(data, filename, video_folder, output_dir, progress_bar, progress_lock):
    """Generate and store the model output for one video.

    Files whose name does not end in ``.mp4`` (case-insensitive) are ignored.
    The result is written to ``<output_dir>/object/<video name>.txt``. If that
    file already exists the video is skipped *before* any frame extraction or
    API call is made. Errors are reported and swallowed so one bad video does
    not stop the batch.

    Args:
        data: Name of the dataset sub-directory inside ``video_folder``.
        filename: Name of the video file inside that sub-directory.
        video_folder: Root folder holding the dataset sub-directories.
        output_dir: Root folder for the generated text files.
        progress_bar: Object with an ``update(n)`` method (e.g. a tqdm bar).
        progress_lock: Lock guarding ``progress_bar`` across threads.
    """
    if not filename.lower().endswith('.mp4'):
        return

    try:
        output_name = filename.rsplit('.', 1)[0] + ".txt"
        output_path = os.path.join(output_dir, "object", output_name)

        # Check first: frame extraction and the API call are the expensive
        # steps and are pointless when the result is already on disk.
        if os.path.exists(output_path):
            print(f"Found existing file: {output_path}")
            return

        video_path = os.path.join(video_folder, data, filename)
        prompt = process_video(video_path, num_frames=16, resolution=(480, 480))
        output = gpt4(prompt)

        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as file:
            file.write(output)

    except Exception as e:
        print(f"Error processing {filename}: {e}")
        traceback.print_exc()
    finally:
        # Count the video as handled whether it succeeded, was skipped or
        # failed; otherwise the bar can never reach its total.
        with progress_lock:
            progress_bar.update(1)

# Process all video files in a subdirectory
def process_videos_in_directory(data, video_folder, output_dir, progress_bar, progress_lock):
    """Run process_single_video, one after another, on each .mp4 in a dataset folder.

    Only entries directly inside ``<video_folder>/<data>`` whose name ends in
    ``.mp4`` (case-insensitive) are handed on; everything else is ignored.
    """
    dataset_dir = os.path.join(video_folder, data)
    mp4_names = filter(
        lambda entry: entry.lower().endswith('.mp4'),
        os.listdir(dataset_dir),
    )
    for entry in mp4_names:
        process_single_video(data, entry, video_folder, output_dir, progress_bar, progress_lock)

# Main entry point
def main(video_folder='VSI-Bench', dataset=('arkitscenes', 'scannet', 'scannetpp'),
         output_dir='output', max_workers=10):
    """Process every .mp4 of every dataset folder on a thread pool.

    One task is submitted per video, so up to ``max_workers`` videos are in
    flight at once. (Submitting one task per dataset folder, as before, capped
    the concurrency at the number of datasets regardless of the pool size.)

    Args:
        video_folder: Root folder holding one sub-directory per dataset.
        dataset: Names of the dataset sub-directories to process.
        output_dir: Root folder for the generated text files.
        max_workers: Size of the thread pool.

    Raises:
        OSError: If a dataset sub-directory cannot be listed.
    """
    # List each directory once; the same job list drives both the progress
    # total and the submitted tasks, so the two cannot drift apart.
    jobs = [
        (data, filename)
        for data in dataset
        for filename in os.listdir(os.path.join(video_folder, data))
        if filename.lower().endswith('.mp4')
    ]
    progress_lock = Lock()

    with tqdm(total=len(jobs), desc="Processing videos") as progress_bar:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(process_single_video, data, filename, video_folder,
                                output_dir, progress_bar, progress_lock)
                for data, filename in jobs
            ]

            # Wait for completion and surface any unexpected exception.
            for future in futures:
                future.result()


if __name__ == '__main__':
    main()
